Allow bucket reshuffling with DreamBooth caches #13712

Open

azolotenkov wants to merge 2 commits into huggingface:main from azolotenkov:feat-bucket-epoch-reshuffle-caching

Conversation

@azolotenkov
Contributor

What does this PR do?

Allow DreamBooth bucket batches to reshuffle each epoch while keeping cached latents and custom-caption prompt embeddings aligned.

After #13353, bucket batches with cached latents/custom captions were kept in stable step order because caches were indexed by dataloader step. This fixes the underlying limitation by indexing cached latents and prompt embeddings by dataset sample index instead. The training dataloader can then reshuffle bucket batches each epoch without reading the wrong cached tensors.

The cache precompute pass now uses a non-dropping cache dataloader, so every sample that can appear in a later reshuffled training epoch has a cache entry.
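The coverage guarantee from a non-dropping dataloader can be sketched as follows (a minimal illustration assuming a standard torch DataLoader; the dataset here is a stand-in, not the PR's actual dataset class):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(5, 3))  # stand-in for the training dataset

# drop_last=False means the remainder batch is visited too, so every sample
# that a reshuffled training epoch could later pick up gets a cache entry.
cache_dataloader = DataLoader(dataset, batch_size=2, shuffle=False, drop_last=False)

num_seen = sum(batch[0].shape[0] for batch in cache_dataloader)
assert num_seen == len(dataset)  # full coverage, including the remainder
```

With drop_last=True the last, smaller batch (one sample here) would be skipped, leaving a cache slot unfilled.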

This also avoids mutating static prompt embeddings inside the training loop. Each step now derives repeated prompt/text embeddings from the original static tensors, which keeps prior-preservation runs with multiple steps stable.
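A minimal sketch of the per-step derivation described above (tensor shapes are illustrative, not taken from the scripts):

```python
import torch

# Hypothetical static embedding computed once before the training loop.
static_prompt_embeds = torch.randn(1, 77, 768)
batch_size = 4

# Derive the per-step embeddings from the immutable static tensor;
# repeat_interleave returns a new tensor instead of mutating the source.
step_embeds = static_prompt_embeds.repeat_interleave(batch_size, dim=0)
assert step_embeds.shape == (4, 77, 768)

# The static tensor is untouched, so later steps (and prior-preservation
# runs spanning multiple steps) always derive from the same source.
assert static_prompt_embeds.shape == (1, 77, 768)
```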

Tested:

Klein smoke tests with hf-internal-testing/tiny-flux2-klein:

  • static prompt, no prior, no cache
  • static prompt, no prior, --cache_latents
  • custom captions, no prior, no latent cache
  • custom captions, no prior, --cache_latents
  • static prompt + prior preservation, no cache
  • static prompt + prior preservation, --cache_latents
  • custom captions + prior preservation, no latent cache
  • custom captions + prior preservation, --cache_latents
  • custom captions + prior preservation + --cache_latents, crossing an epoch boundary with max_train_steps=7

Flux2 smoke tests with hf-internal-testing/tiny-flux2 using the standard tiny-model settings:

  • no prior preservation, no cache, train_batch_size=1, max_train_steps=2
  • prior preservation, train_batch_size=2, max_train_steps=2
  • prior preservation + --cache_latents, train_batch_size=2, max_train_steps=3

Before submitting

Who can review?

@sayakpaul

Copilot AI review requested due to automatic review settings May 10, 2026 19:42
@github-actions github-actions Bot added examples size/L PR with diff > 200 LOC labels May 10, 2026
Contributor

Copilot AI left a comment


Pull request overview

Enables epoch-wise reshuffling of DreamBooth bucketed batches in the Flux2 DreamBooth LoRA example scripts while keeping cached latents and (custom-caption) prompt/text embeddings correctly aligned by switching caches from step-indexing to dataset-sample indexing.

Changes:

  • Add per-sample index to dataset items and propagate it through collate_fn so caches can be keyed by sample index rather than dataloader step.
  • Rework latent/prompt-embedding caching to precompute via a non-dropping cache dataloader and store per-sample cached tensors.
  • Update BucketBatchSampler to reshuffle indices/batches on each __iter__() call (epoch reshuffle) while keeping __len__ stable.
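The sampler behavior listed above can be sketched like this (a hypothetical simplification, not the PR's actual implementation, shown as a plain class rather than a torch Sampler subclass):

```python
import random

class BucketBatchSampler:
    """Samples are shuffled within each bucket, chunked into batches, and
    the batch order is reshuffled on every __iter__ call (i.e. each epoch),
    while __len__ stays stable across epochs."""

    def __init__(self, bucket_indices, batch_size):
        self.bucket_indices = bucket_indices  # e.g. {bucket_id: [sample indices]}
        self.batch_size = batch_size

    def __iter__(self):
        batches = []
        for indices in self.bucket_indices.values():
            indices = list(indices)
            random.shuffle(indices)  # fresh within-bucket order each epoch
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i : i + self.batch_size])
        random.shuffle(batches)  # fresh batch order across buckets each epoch
        return iter(batches)

    def __len__(self):
        # Stable: the batch count does not depend on how indices are shuffled.
        return sum(
            (len(v) + self.batch_size - 1) // self.batch_size
            for v in self.bucket_indices.values()
        )

sampler = BucketBatchSampler({0: [0, 1, 2], 1: [3, 4]}, batch_size=2)
epoch1 = list(sampler)
epoch2 = list(sampler)  # independent reshuffle, same length
assert len(epoch1) == len(epoch2) == len(sampler) == 3
assert sorted(i for b in epoch1 for i in b) == [0, 1, 2, 3, 4]
```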

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File: examples/dreambooth/train_dreambooth_lora_flux2.py
Description: Switch cache alignment to sample-indexing and re-enable epoch reshuffling for bucket batches.

File: examples/dreambooth/train_dreambooth_lora_flux2_klein.py
Description: Same cache-by-index + epoch reshuffle changes for the Klein variant of the training script.


Comment thread examples/dreambooth/train_dreambooth_lora_flux2.py
Comment thread examples/dreambooth/train_dreambooth_lora_flux2_klein.py
@azolotenkov azolotenkov force-pushed the feat-bucket-epoch-reshuffle-caching branch from 6cf74c0 to b022936 Compare May 10, 2026 19:56
@sayakpaul
Member

@claude can you do a review here?

@github-actions
Contributor

github-actions Bot commented May 11, 2026

Claude finished @sayakpaul's task in 5m 3s


Review: Allow bucket reshuffling with DreamBooth caches

  • Read review guidelines (AGENTS.md)
  • Read the full changes in both files
  • Trace cache precompute and training-loop retrieval paths
  • Post review feedback

Overall: The approach is sound. Switching from step-indexed to sample-indexed caching correctly decouples cache lookups from batch ordering, enabling epoch reshuffling without misaligned cache reads. The non-dropping cache dataloader (drop_last=False) ensures every sample has a cache entry. Both files are consistent with one another (modulo Klein-specific differences: no remote_text_encoder path, Flux2KleinPipeline vs Flux2Pipeline).

No blocking issues found. A few observations and suggestions below.


Correctness trace

I traced the full flow and confirmed the design holds:

  1. BucketBatchSampler generates indices 0..len(pixel_values)-1, which equals num_instance_images (already includes repeats). Since cache lists are sized [None] * num_instance_images, all indices are in-bounds.

  2. Cache precompute (drop_last=False) visits every sample. Training dataloader (drop_last=True) may drop remainder batches, but those samples still have cache entries. Reshuffling changes which samples land in the remainder, but the cache has full coverage regardless.

  3. Prior-preservation alignment: collate_fn orders batches as [inst1..instB, class1..classB]. The cache retrieval reconstructs this same layout: instance latents first, then class latents. Prompt handling follows the same pattern (custom prompts concatenated with repeated class prompts). This is consistent.

  4. Static prompt path: repeat_interleave from the immutable static_prompt_embeds/static_text_ids avoids mutating shared tensors across steps — this is a nice improvement over the prior pattern.
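The layout traced in point 3 can be illustrated with a hypothetical retrieval helper (names and the assumption that class latents share the same per-sample indexing are mine, not taken from the diff):

```python
def gather_cached_latents(batch_sample_indices, instance_latents_cache, class_latents_cache):
    # Instance latents first, looked up by dataset sample index...
    instance = [instance_latents_cache[i] for i in batch_sample_indices]
    # ...then class latents, appended after to mirror collate_fn's
    # [inst1..instB, class1..classB] ordering.
    class_part = [class_latents_cache[i] for i in batch_sample_indices]
    return instance + class_part

# Strings stand in for cached tensors to keep the sketch self-contained.
inst_cache = ["inst_0", "inst_1", "inst_2"]
cls_cache = ["cls_0", "cls_1", "cls_2"]

# A reshuffled batch of sample indices [2, 0] still reads the right entries.
batch = gather_cached_latents([2, 0], inst_cache, cls_cache)
assert batch == ["inst_2", "inst_0", "cls_2", "cls_0"]
```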


Suggestions (non-blocking)

  1. Consider an assertion after the cache loop (train_dreambooth_lora_flux2.py:1645, train_dreambooth_lora_flux2_klein.py:1586)

    After caching, there's no validation that every slot was filled. If a future change accidentally breaks the sampler, you'd get a silent None in the training loop (likely surfacing as a confusing torch.cat error). A quick sanity check would surface the real cause:

    if args.cache_latents:
        assert all(x is not None for x in instance_latents_cache), "Latent cache has unfilled entries"
  2. Unused batch["instance_prompts"] in training loop

    During training, batch["instance_prompts"] is never read — prompts come from prompt_embeds_cache (custom caption path) or static_prompt_embeds (static path). It's harmless but does mean the string data travels through the dataloader for no reason. Not worth changing now, just noting for awareness.

  3. Minor: collate_fn initializes both instance_prompts and prompts from the same source (train_dreambooth_lora_flux2.py:963-964, train_dreambooth_lora_flux2_klein.py:958-959)

    instance_prompts = [example["instance_prompt"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]

    These are separate list objects (so the prompts += class_prompts mutation is safe), but at first glance it reads like a duplication bug. A short inline comment explaining that instance_prompts stays as-is for cache keying while prompts gets extended for prior preservation would help future readers.
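To make the non-duplication concrete, here is a small self-contained sketch of why the two separate lists are safe (the `examples` data is hypothetical):

```python
# Hypothetical dataset items standing in for `examples` in collate_fn.
examples = [{"instance_prompt": "a photo of sks dog"}]

# instance_prompts stays as-is so caches can be keyed per instance sample;
# prompts is a deliberately separate list object that gets extended with
# class prompts when prior preservation is enabled.
instance_prompts = [example["instance_prompt"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]

class_prompts = ["a photo of a dog"]  # hypothetical class prompt
prompts += class_prompts  # in-place extend mutates only `prompts`

assert instance_prompts == ["a photo of sks dog"]
assert prompts == ["a photo of sks dog", "a photo of a dog"]
```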


Summary

The core design (sample-indexed caches + drop_last=False precompute + reshuffling __iter__) is correct and well-tested across the configurations listed in the PR description. LGTM with the optional suggestions above.


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
